package com.zjj.lbw.ai;

import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.lang3.StringUtils;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtils;
import org.jfree.chart.JFreeChart;
import org.jfree.data.general.DefaultPieDataset;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;

import java.io.IOException;
import java.io.OutputStream;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Demo "DbChat" controller: turns a natural-language question into SQL with a chat model,
 * runs that SQL through {@link JdbcTemplate}, and renders the result as a pie chart.
 *
 * <p>Two prompts are sent to the model: the first asks it to pick the tables relevant to the
 * question out of the database schema, the second asks it to write SQL against those tables.
 */
@Controller
public class SqlController {

    @Resource
    private ChatModel chatModel;

    @Resource
    private JdbcTemplate jdbcTemplate;

    /**
     * Prompt that asks the model to select the relevant table descriptions.
     * Placeholders: {@code instruction} (schema description) and {@code input} (user question).
     */
    private static final String FILTER_INSTRUCTION = """
    你需要根据指定的Input从Instruction中筛选出最相关的表信息（可能是单个表或多个表），
    首先，我将给你展示一个示例，Instruction后面跟着Input和对应的Response，
    然后，我会给你一个新的Instruction和新的Input，你需要生成一个新的Response来完成任务。

    ### Example1 Instruction:
    job(id, name, age), user(id, name, age), student(id, name, age, info)
    ### Example1 Input:
    Find the age of student table
    ### Example1 Response:
    student(id, name, age, info)
    ###New Instruction:
    {instruction}
    ###New Input:
    {input}
    ###New Response:
    """;

    /**
     * Prompt that asks the model to answer with a bare SQL statement.
     * Placeholders: {@code instruction} (selected tables) and {@code input} (user question).
     */
    private static final String GENERATE_INSTRUCTION = """
    你扮演一个SQL终端，您只需要返回SQL命令给我，而不需要返回其他任何字符。下面是一个描述任务的Instruction，返回适当的结果完成Input对应的请求.
    ###Instruction:
    {instruction}
    ###Input:
    {input}
    ###Response:
    """;

    /**
     * Answers a natural-language question with a PNG pie chart.
     *
     * <p>The first column of every result row is used as the slice label and the second column
     * as the slice value. Rows whose label or value is {@code null} are skipped.
     *
     * @param query    the user's question in natural language
     * @param response servlet response the PNG image is written to
     * @throws SQLException             if reading the database metadata fails
     * @throws IOException              if writing the image to the response fails
     * @throws IllegalArgumentException if the generated SQL returns fewer than two columns, or a
     *                                  non-numeric value in the second column
     */
    @GetMapping("/sql/chat")
    public void chat(@RequestParam("query") String query, HttpServletResponse response) throws SQLException, IOException {

        String schemaDescription = describeSchema(getTableInfo());

        // Step 1: let the model narrow the schema down to the tables relevant to the question.
        String relevantTables = chatModel.call(renderPrompt(FILTER_INSTRUCTION, schemaDescription, query));

        // Step 2: let the model write the SQL against the selected tables.
        String sql = chatModel.call(renderPrompt(GENERATE_INSTRUCTION, relevantTables, query));
        // The model may wrap its answer in a Markdown code fence; remove the fence markers.
        sql = sql.replace("```sql", "");
        sql = sql.replace("```", "");
        System.out.println(sql);

        // NOTE(review): the SQL text is produced by the model from the caller-supplied query and is
        // executed verbatim. Consider a read-only database account and/or statement validation
        // before exposing this endpoint to untrusted callers.
        List<Map<String, Object>> rows = jdbcTemplate.queryForList(sql);

        // Build the JFreeChart object.
        JFreeChart chart = ChartFactory.createPieChart(
                "统计结果", // chart title
                toPieDataset(rows), // dataset
                false,
                true,
                true);

        // The response body is an image.
        response.setContentType("image/png");
        OutputStream out = response.getOutputStream();

        // Write the chart into the response output stream.
        ChartUtils.writeChartAsPNG(out, chart, 800, 600);
        out.flush();
    }

    /**
     * Reads every table (type {@code TABLE}) and its columns from the database metadata.
     *
     * <p>Each column is rendered as {@code name(remarks)}; when the column has no remark the bare
     * column name is used, so that the prompt never contains {@code name(null)} or {@code name()}.
     *
     * @return table name mapped to the rendered column descriptions of that table
     * @throws SQLException          if the metadata cannot be read
     * @throws IllegalStateException if the {@link JdbcTemplate} has no data source
     */
    public Map<String, List<String>> getTableInfo() throws SQLException {
        var dataSource = jdbcTemplate.getDataSource();
        if (dataSource == null) {
            throw new IllegalStateException("JdbcTemplate has no DataSource configured");
        }

        Map<String, List<String>> result = new HashMap<>();
        // try-with-resources: the connection and the result sets must be released,
        // otherwise every call leaks a connection.
        try (var connection = dataSource.getConnection()) {
            // Fetch the database metadata.
            DatabaseMetaData metaData = connection.getMetaData();
            try (ResultSet tables = metaData.getTables(null, null, "%", new String[]{"TABLE"})) {
                while (tables.next()) {
                    String tableName = tables.getString("TABLE_NAME");
                    result.put(tableName, readColumnDescriptions(metaData, tableName));
                }
            }
        }

        return result;
    }

    /** Reads the column descriptions of one table, closing the metadata result set afterwards. */
    private static List<String> readColumnDescriptions(DatabaseMetaData metaData, String tableName) throws SQLException {
        List<String> columnDescriptions = new ArrayList<>();
        try (ResultSet columns = metaData.getColumns(null, null, tableName, null)) {
            while (columns.next()) {
                String columnName = columns.getString("COLUMN_NAME");
                String remarks = columns.getString("REMARKS");
                columnDescriptions.add(StringUtils.isBlank(remarks)
                        ? columnName
                        : String.format("%s(%s)", columnName, remarks));
            }
        }
        return columnDescriptions;
    }

    /** Renders the schema as a comma-separated list of {@code table(col1,col2,...)} entries. */
    private static String describeSchema(Map<String, List<String>> tableInfo) {
        List<String> tableDescriptions = tableInfo.entrySet().stream()
                .map(entry -> String.format("%s(%s)", entry.getKey(), StringUtils.join(entry.getValue(), ",")))
                .toList();
        return StringUtils.join(tableDescriptions, ",");
    }

    /** Fills the {@code instruction} and {@code input} placeholders of a prompt template. */
    private static String renderPrompt(String template, String instruction, String input) {
        PromptTemplate promptTemplate = new PromptTemplate(template);
        promptTemplate.add("instruction", instruction);
        promptTemplate.add("input", input);
        return promptTemplate.render();
    }

    /** Converts query rows into a pie dataset: first column is the label, second is the value. */
    private static DefaultPieDataset toPieDataset(List<Map<String, Object>> rows) {
        DefaultPieDataset dataset = new DefaultPieDataset();
        for (Map<String, Object> row : rows) {
            Object[] cells = row.values().toArray();
            if (cells.length < 2) {
                throw new IllegalArgumentException(
                        "A pie chart needs a label column and a value column, but the query returned "
                                + cells.length + " column(s)");
            }
            Object label = cells[0];
            Object amount = cells[1];
            if (label == null || amount == null) {
                // Nothing to draw for this row (e.g. an aggregate over no rows yields NULL).
                continue;
            }
            dataset.setValue(label.toString(), toNumber(amount));
        }
        return dataset;
    }

    /**
     * Converts a result cell to a number. Numeric cells (Long, BigDecimal, Double, ...) are used
     * as-is; anything else is parsed from its string form so decimal values are accepted too.
     */
    private static Number toNumber(Object value) {
        if (value instanceof Number number) {
            return number;
        }
        return Double.valueOf(value.toString().trim());
    }
}
